
[PyTorch] [torch.compile] transformer_engine.pytorch.autocast support inside torch.compile #2759

Open
pggPL wants to merge 14 commits into NVIDIA:main from pggPL:torch_compile_autocast

Conversation

@pggPL (Collaborator) commented Mar 13, 2026

Description

Enable torch.compile(fullgraph=True) for FP8 autocast by moving compile-visible mutable state off class attributes, avoiding tracing through support checks, and adding tests.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Move mutable FP8 autocast state from direct cls attribute writes to a dataclass-backed singleton object, because torch.compile does not support writes directly to class attributes.
  • Replace lru_cache-based support checks with explicit module-level caches and mark the wrapper functions with @torch.compiler.assume_constant_result so torch.compile does not trace into check_*_support().
  • Add torch.compile coverage for FP8 autocast using a custom test module; the test is more involved because there is currently no simple TE layer that supports both FP8 and torch.compile.
  • Make DelayedScaling explicitly unsupported under torch.compile and raise a clear error for that case.
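The first two changes above can be sketched as follows. This is a hypothetical illustration, not TE's actual code: FP8GlobalState, quantization_state, and check_fp8_support are stand-in names, and the hardware query is stubbed out.

```python
# Sketch of (1) instance-backed mutable state and (2) a constant-result
# support check. All names here are illustrative stand-ins for the TE code.
from dataclasses import dataclass, field

try:
    from torch.compiler import assume_constant_result
except ImportError:  # let the sketch run without torch installed
    assume_constant_result = lambda fn: fn


@dataclass
class FP8GlobalState:
    # Mutable autocast state lives on an instance: Dynamo can trace
    # attribute writes on an object, but not writes to class attributes.
    fp8_enabled: bool = False
    autocast_depth: int = 0
    autocast_arguments: dict = field(default_factory=dict)


quantization_state = FP8GlobalState()

# Explicit module-level cache instead of functools.lru_cache, so the
# wrapper below stays a plain function torch.compile can treat as constant.
_FP8_SUPPORT = None


def _query_fp8_support():
    # Stand-in for the real hardware-capability query.
    return (True, "supported")


@assume_constant_result
def check_fp8_support():
    # torch.compile bakes in the returned value instead of tracing the check.
    global _FP8_SUPPORT
    if _FP8_SUPPORT is None:
        _FP8_SUPPORT = _query_fp8_support()
    return _FP8_SUPPORT
```

With this shape, `quantization_state.fp8_enabled = True` is an ordinary object-attribute write that fullgraph tracing supports, and repeated calls to `check_fp8_support()` return the cached tuple without re-entering the query.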

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes


pggPL added 3 commits March 13, 2026 14:02
Move FP8 global state onto an instance so Dynamo can trace autocast state updates, explicitly reject DelayedScaling under torch.compile, and add toy compile tests that keep TE forward/backward opaque while covering supported recipes.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Drop the standalone global dict and dataclass mutation experiments now that the torch.compile regression coverage lives in the focused autocast test file.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Use compiler constant-result wrappers for support checks and rename the module-level FP8 singleton to `_FP8_GLOBAL_STATE` for clearer semantics.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@pggPL pggPL force-pushed the torch_compile_autocast branch from 338ddae to b5d46fd Compare March 13, 2026 13:03
pggPL and others added 4 commits March 13, 2026 14:24
Restore the FP8 naming and remove extra state access helpers so the torch.compile changes stay focused on the instance-backed global state.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Drop stale availability fields from FP8GlobalState now that support checks use module-level cached results instead of manager state.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Resolve conflicts in the FP8 torch.compile changes while preserving the upstream updates in graph.py and module/base.py.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@pggPL pggPL marked this pull request as ready for review March 13, 2026 14:09
@greptile-apps bot (Contributor) commented Mar 13, 2026

Greptile Summary

This PR enables torch.compile(fullgraph=True) for FP8 autocast by making two structural changes to transformer_engine/pytorch/quantization.py: (1) moving all mutable FP8 global state off class attributes into an FP8GlobalState dataclass singleton (quantization_state), so writes no longer hit setattr(type, ...) which torch.compile cannot trace; and (2) replacing functools.lru_cache-wrapped support-check functions with explicit module-level caches guarded by @torch.compiler.assume_constant_result, preventing the compiler from tracing into hardware-capability checks. DelayedScaling is explicitly blocked under torch.compile with a clear error. The remaining 9 files are mechanical migrations from the old class-attribute spellings to the new quantization_state.* accessor.

Key issues identified (not yet discussed in prior threads):

  • Private PyTorch internal API imports in test_torch_compile.py (lines 10-15): torch._opaque_base.OpaqueBaseMeta and torch._library.opaque_object.* are underscore-prefixed internals with no stability contract. A future PyTorch minor release could move or rename them, silently breaking every test in the file with an ImportError. The imports should be guarded with a try/except and a version-gating comment, or replaced by any equivalent public surface when available.

  • test_autocast_sanity uses torch.nn.Linear instead of ToyLinear (lines 310-329): Wrapping a plain torch.nn.Linear in te.autocast does not trigger any TE quantization. All four recipe variants (Float8CurrentScaling, Float8BlockScaling, MXFP8BlockScaling, NVFP4BlockScaling) execute identical BF16 code, so the test provides zero evidence that any recipe actually works end-to-end under torch.compile. The ToyLinear module defined in the same file (precisely for this purpose) should be used instead.

Previously-flagged issues still open: get_default_fp8_recipe() called unconditionally in autocast_enter when recipe=None (will assert under torch.compile); reset() swapping the singleton object invalidating compiled-graph guards; assert in the compile branch of fp8_graph_capturing stripped by -O; id()-based autocast keys risking unbounded dict growth; check_recipe_support not guarding the recipe=None + enabled=True + compile case; and get_fp8_recipe() falling through to get_default_fp8_recipe() when called without an active autocast context during compilation.

Confidence Score: 2/5

  • Several correctness issues in the compile path remain unaddressed, and the new test suite provides weaker coverage than it appears to.
  • Score reflects two categories of open risk: (1) The autocast_enter function unconditionally calls get_default_fp8_recipe() when recipe=None, which asserts not torch.compiler.is_compiling(). This means calling te.autocast() (no recipe) inside torch.compile — even with enabled=False — raises an opaque AssertionError rather than a user-friendly error. (2) FP8GlobalStateManager.reset() replaces the quantization_state singleton with a new object, invalidating all torch.compile guards that captured the old object's identity, causing full recompilation on every subsequent call. Additionally, the two new test issues identified in this review mean the coverage picture for "recipes work under torch.compile" is not as complete as the PR description implies. The mechanical migrations across the other 8 files look correct.
  • transformer_engine/pytorch/quantization.py (lines 317, 469, 486, 574, 591) and tests/pytorch/test_torch_compile.py (lines 10–15, 310–329) need the most attention before merging.
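One of the flagged risks, reset() swapping the singleton and invalidating compiled-graph guards, admits a small guard-friendly alternative. A minimal sketch with hypothetical names (FP8GlobalState, quantization_state, reset are stand-ins, and the field list is abbreviated): reset the fields in place so the object identity that compiled graphs may guard on survives.

```python
# Sketch: reset the singleton's fields in place instead of rebinding the
# name to a fresh object, preserving the identity compiled graphs guard on.
from dataclasses import dataclass, fields


@dataclass
class FP8GlobalState:  # abbreviated stand-in for the real state object
    fp8_enabled: bool = False
    autocast_depth: int = 0


quantization_state = FP8GlobalState()


def reset():
    # Copy defaults field-by-field; `quantization_state = FP8GlobalState()`
    # here would invalidate guards captured on the old object's id() and
    # force recompilation on every subsequent call.
    defaults = FP8GlobalState()
    for f in fields(FP8GlobalState):
        setattr(quantization_state, f.name, getattr(defaults, f.name))
```

After reset(), id(quantization_state) is unchanged while every field is back at its default.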

Important Files Changed

Filename Overview
transformer_engine/pytorch/quantization.py Core state-management refactor: replaces per-class-attribute writes with a FP8GlobalState dataclass singleton. Several correctness risks remain: get_default_fp8_recipe() is called unconditionally in autocast_enter (line 591) and will assert when recipe=None under torch.compile even for enabled=False; assert on line 470 is silently stripped by -O; reset() replaces the singleton object, breaking torch.compile guards on the old identity; and id()-based autocast keys risk unbounded dict growth.
tests/pytorch/test_torch_compile.py New test file introduces ToyLinear and two tests. test_autocast_nested_custom provides meaningful coverage; test_autocast_sanity uses torch.nn.Linear (no TE quantization), so the parametrized recipe variants are effectively untested. Imports from private PyTorch internals (torch._opaque_base, torch._library.opaque_object) add fragility against PyTorch version upgrades.
transformer_engine/pytorch/module/base.py Mechanically migrates global_amax_buffer / global_amax_history_buffer accesses to quantization_state.*. Write order changed (history before amax), which is inconsistent with ops/op.py but not a correctness issue as the two assignments are independent.
transformer_engine/pytorch/ops/op.py Mechanically migrates buffer accesses to quantization_state.*. Write order retained (amax before history), diverging from the new base.py order but not a functional issue.
transformer_engine/pytorch/module/layernorm_mlp.py Replaces removed get_autocast_state/set_autocast_state helpers with inline tuple unpacking. The inlined state tuple matches the six fields captured elsewhere; correctness is preserved.
transformer_engine/pytorch/graph.py Inlines the removed set_skip_fp8_weight_update_tensor helper. Logic is equivalent; the null-check before tensor creation matches the old helper.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["te.autocast(recipe=R, enabled=E)"] --> B{enabled?}
    B -- yes --> C["check_recipe_support(recipe)"]
    C --> D{is_compiling AND\nDelayedScaling?}
    D -- yes --> E["RuntimeError ✗"]
    D -- no --> F["save fp8_state tuple\nfrom quantization_state.*"]
    B -- no --> F
    F --> G["autocast_enter(enabled, recipe, ...)"]
    G --> H{recipe is None?}
    H -- yes --> I["get_default_fp8_recipe()\nassert not is_compiling ⚠️"]
    H -- no --> J["quantization_state.autocast_arguments[key] = recipe,group\nquantization_state.fp8_* = new values"]
    I --> J
    J --> K["yield — user code runs"]
    K --> L["finally: restore quantization_state.*\nfrom saved fp8_state tuple"]
    L --> M["autocast_exit: depth--, reduce amaxes if needed"]

    subgraph "support cache (module-level globals)"
        N["_FP8_SUPPORT / _MXFP8_SUPPORT\n_NVFP4_SUPPORT / _FP8_BLOCK_SCALING_SUPPORT"]
        O["check_fp8_support() etc.\n@assume_constant_result"] --> N
    end

    subgraph "FP8GlobalState singleton"
        P["quantization_state\n.fp8_enabled / .fp8_recipe\n.autocast_depth / .is_first_fp8_module\n.global_amax_buffer / .autocast_arguments\n..."]
    end

    J -.-> P
    L -.-> P

Reviews (5): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

pggPL and others added 2 commits March 23, 2026 12:12
Replace custom-op-based ToyLinear with a minimal version using F.linear.
Add test_autocast_sanity (parametrized over all recipes including NVFP4)
and test_autocast_nested_sanity with CustomRecipes. Both verify
fullgraph=True compilation without graph breaks.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Made-with: Cursor
Verify that te.autocast(recipe=DelayedScaling(), enabled=True) raises
a clear RuntimeError when used inside torch.compile.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Made-with: Cursor
pggPL and others added 2 commits March 23, 2026 12:46
Use str(recipe) for content-based recipe keying (avoids unbounded growth
when identical recipes are constructed inline) and id(group) for process
group identity (same semantics as the old hash(group) which was id-based).

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Made-with: Cursor
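The keying scheme in the commit above can be illustrated with a small sketch; Recipe and autocast_key are hypothetical stand-ins, not TE's actual API.

```python
# Sketch of content-based recipe keying plus identity-based group keying.
from dataclasses import dataclass


@dataclass
class Recipe:  # stand-in for a TE recipe class such as Float8CurrentScaling
    margin: int = 0
    fp8_format: str = "E4M3"


def autocast_key(recipe, group):
    # str(recipe) keys by content, so identical recipes constructed inline
    # on every call map to a single entry and the autocast_arguments dict
    # cannot grow without bound; id(group) preserves the identity semantics
    # of the old hash(group), which was id-based for process groups.
    return (str(recipe), id(group))
```

Two freshly constructed but identical recipes then share a key, while recipes with different contents (or different process groups) do not.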
@pggPL (Collaborator, Author) commented Mar 23, 2026

/te-ci pytorch L1

pggPL and others added 2 commits March 25, 2026 17:46
Replace custom_op-based approach with torch.library.define/impl/register_fake
using get_opaque_type_name() in the schema, which allows Inductor to properly
handle opaque value types. Add ToyQuantizer as an opaque value-type wrapper
around Float8CurrentScalingQuantizer with proper __eq__/__hash__/__fx_repr__.

test_autocast_nested_custom validates that nested te.autocast with 3 distinct
CustomRecipe instances passes the correct quantizers in both forward and backward.
test_autocast_sanity is a smoke test for all hardware-supported built-in recipes.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Made-with: Cursor
Comment on lines +10 to +15
from torch._opaque_base import OpaqueBaseMeta
from torch._library.opaque_object import (
    get_opaque_type_name,
    register_opaque_type,
    MemberType,
)

P1 Private PyTorch internal APIs imported without a stability guarantee

torch._opaque_base and torch._library.opaque_object are private modules (underscore-prefixed). PyTorch treats them as implementation details — their names, locations, and interfaces can change between minor releases without deprecation warnings. A future PyTorch upgrade could silently break every test in this file with an ImportError or AttributeError.

If there is an equivalent public API (e.g., something under torch.library without the underscore), it should be used instead. If no public API exists yet, the import should at minimum be wrapped in a version guard and a comment explaining why the private API is used:

# torch._opaque_base / torch._library.opaque_object are private APIs as of PyTorch 2.x.
# These imports will need to be updated if/when a public surface is provided.
try:
    from torch._opaque_base import OpaqueBaseMeta
    from torch._library.opaque_object import get_opaque_type_name, register_opaque_type, MemberType
except ImportError:
    # pytest.skip at module level requires allow_module_level=True
    pytest.skip("Private torch opaque-type APIs not available in this PyTorch build", allow_module_level=True)

Comment on lines +310 to +329
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("fp8_recipe", _all_recipes, ids=lambda r: type(r).__name__)
def test_autocast_sanity(fp8_recipe):
    """Smoke test: torch.nn.Linear inside a single te.autocast with each
    built-in recipe. Forward + backward under torch.compile(fullgraph=True)."""
    dtype = torch.bfloat16
    device = "cuda"

    model = torch.nn.Linear(32, 32, dtype=dtype, device=device)
    inp = torch.randn(8, 32, dtype=dtype, device=device, requires_grad=True)

    def fn(inp):
        with te.autocast(recipe=fp8_recipe):
            return model(inp)

    torch._dynamo.reset()
    compiled = torch.compile(fn, fullgraph=True)

    out = compiled(inp)
    out.sum().backward()

P1 test_autocast_sanity uses torch.nn.Linear — no TE quantization is exercised

torch.nn.Linear is a standard PyTorch module that is completely unaware of TE's FP8 machinery. Wrapping it in te.autocast(recipe=fp8_recipe) does not cause any quantization to occur; the forward and backward passes run entirely in BF16 regardless of which recipe is passed. This means that despite the parametrization over Float8CurrentScaling, Float8BlockScaling, MXFP8BlockScaling, and NVFP4BlockScaling, all four test variants exercise identical code paths and provide no evidence that any of those recipes actually work under torch.compile.

ToyLinear (defined in this same file specifically for this purpose) should be used instead. It calls into TE's BasicLinear functional ops and respects the active autocast recipe, so it would actually exercise the quantization code paths:

model = ToyLinear(32, 32, dtype=dtype, device=device)

If ToyLinear cannot be driven with all four recipes today, those recipe variants should receive their own skip guards with a comment explaining the gap.

@pggPL (Collaborator, Author) commented Mar 30, 2026

/te-ci pytorch

@pggPL changed the title from "[PyTorch] transformer_engine.pytorch.autocast support inside torch.compile" to "[PyTorch] [torch.compile] transformer_engine.pytorch.autocast support inside torch.compile" on Mar 30, 2026